package com.github.xzb617.client.flow.adapte.webflux.filter;

import com.github.xzb617.client.flow.adapte.webflux.util.QueryParamMapUtil;
import com.github.xzb617.client.flow.adapte.webflux.util.WebFluxUtil;
import com.github.xzb617.client.flow.cache.CounterCache;
import com.github.xzb617.client.flow.core.*;
import com.github.xzb617.client.flow.core.hotparam.HotParamFlowLimiter;
import com.github.xzb617.client.flow.core.rules.RulesFlowLimiter;
import com.github.xzb617.client.flow.core.statics.Statics;
import com.github.xzb617.client.flow.core.system.SystemProtectedFlowLimiter;
import com.github.xzb617.client.flow.fallback.FlowFallback;
import com.github.xzb617.client.flow.fallback.FlowResponse;
import com.github.xzb617.client.flow.utils.DateUtil;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.util.MultiValueMap;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
 * WebFlux {@link WebFilter} that applies client-side flow control to each incoming request.
 *
 * <p>For every request it records the request in {@link Statics}, builds a {@link FlowContext}
 * from the current QPS, client IP, request path and query parameters, and runs that context
 * through a chain of limiters (system protection, hot-parameter, rules). If a limiter throws a
 * {@link FlowException}, the configured {@link FlowFallback} produces the response that is written
 * back instead of continuing the filter chain; otherwise the request is counted as passed and
 * handed on to the next filter.
 */
public class FlowControlWebfluxFilter implements WebFilter {

    private final FlowProperties flowProperties;
    private final CounterCache counterCache;
    private final FlowFallback flowFallback;

    /**
     * Creates the filter.
     *
     * @param flowProperties flow-control configuration handed to the limiter chain
     * @param counterCache   counter storage handed to the limiter chain
     * @param flowFallback   produces the response written when a request is limited
     */
    public FlowControlWebfluxFilter(FlowProperties flowProperties, CounterCache counterCache, FlowFallback flowFallback) {
        this.flowProperties = flowProperties;
        this.counterCache = counterCache;
        this.flowFallback = flowFallback;
    }

    /**
     * Runs the limiter chain for the request and either writes the fallback response (when the
     * request is limited) or delegates to the rest of the filter chain.
     *
     * @param serverWebExchange the current exchange
     * @param webFilterChain    the remaining filters
     * @return completion of the fallback write, or of the downstream filter chain
     */
    @Override
    public Mono<Void> filter(ServerWebExchange serverWebExchange, WebFilterChain webFilterChain) {
        ServerHttpRequest request = serverWebExchange.getRequest();

        // Record this request in the QPS statistics.
        long currentTimestamp = DateUtil.getCurrentTimestamp();
        Statics.addTotal(currentTimestamp);

        // Assemble the FlowContext that the limiters evaluate.
        long qps = Statics.getCurrentQPS(currentTimestamp);
        String ip = WebFluxUtil.getClientIpAddr(request);
        String uri = request.getURI().getPath();
        MultiValueMap<String, String> queryParams = request.getQueryParams();
        Map<String, String[]> paramMap = QueryParamMapUtil.toMapOfArray(queryParams);

        FlowContext flowContext = new FlowContext(qps, ip, uri, paramMap);

        // Run the limiter chain; a FlowException means the request has been limited.
        FlowLimiterChain chain = this.buildLimiterChain();
        try {
            chain.execute(flowContext, flowProperties, counterCache, chain);
        } catch (FlowException fe) {
            return this.writeFallbackResponse(serverWebExchange, fe);
        }

        // Reaching this point without a FlowException means the request passed.
        Statics.addPass(currentTimestamp);

        return webFilterChain.filter(serverWebExchange);
    }

    /**
     * Writes the response produced by the configured fallback for a limited request.
     *
     * @param serverWebExchange the current exchange
     * @param fe                the limiter exception that rejected the request
     * @return completion of the response write
     */
    private Mono<Void> writeFallbackResponse(ServerWebExchange serverWebExchange, FlowException fe) {
        FlowResponse flowResponse = this.flowFallback.fallback(fe);
        ServerHttpResponse response = serverWebExchange.getResponse();

        // HttpStatus.resolve returns null for codes that are not standard HTTP statuses. Passing
        // null to setStatusCode would typically leave the response at its default status, which
        // would report a rejected request as successful, so fall back to 429 in that case.
        HttpStatus status = HttpStatus.resolve(flowResponse.getResponseStatus());
        response.setStatusCode(status != null ? status : HttpStatus.TOO_MANY_REQUESTS);

        // Encode explicitly: String.getBytes() with no argument uses the platform default charset.
        byte[] body = flowResponse.getResponseContent().getBytes(StandardCharsets.UTF_8);
        DataBuffer dataBuffer = response.bufferFactory().wrap(body);
        return response.writeWith(Mono.just(dataBuffer));
    }

    /**
     * Builds the limiter chain applied to each request. The limiters run in this order:
     * system protection, hot-parameter limiting, then rule-based limiting.
     *
     * <p>A fresh chain is built per request. NOTE(review): presumably {@link FlowLimiterChain}
     * tracks its progress during {@code execute}, so instances should not be shared across
     * requests. Verify this before caching the chain.
     *
     * @return a new limiter chain
     */
    private FlowLimiterChain buildLimiterChain() {
        List<FlowLimiter> flowLimiters = new ArrayList<>();
        flowLimiters.add(new SystemProtectedFlowLimiter());
        flowLimiters.add(new HotParamFlowLimiter());
        flowLimiters.add(new RulesFlowLimiter());
        return new FlowLimiterChain(flowLimiters);
    }
}
